
feat(body_axis): add skeleton-agnostic AP inference #945

Draft
khan-u wants to merge 21 commits into neuroinformatics-unit:main from khan-u:body-axis-ap-inference



@khan-u commented Apr 2, 2026

Description

Skeleton-Agnostic Body-Axis Inference Pipeline

This draft PR depends on #875 (also a draft). The work here covers skeleton-agnostic body-axis inference; its utility for computing polarization is emphasized in:

compute_polarization – Revised API

#875 (comment)

For the latest updates on this draft PR, see the comments below.

body_axis – Revised API

What is this PR

  • Bug fix
  • Addition of a new feature
  • Other

Why is this PR needed?

This PR addresses a practical problem in computing orientation polarization from body-axis keypoints: the user must specify which keypoint pair defines the posterior → anterior axis, i.e., the from_node and to_node.

The AP validation pipeline introduced here automatically assesses that choice and, when the input pair is not well supported, suggests a better alternative.

It leverages the general tendency of animals to move head-first and requires no species-specific anatomical knowledge or keypoint-identity priors, relying only on general body geometry and observed movement.

What does this PR do?

The core question it answers is:

Given a set of keypoints for an animal, which direction is "front" (anterior) and which is "back" (posterior)?

The pipeline is implemented in a new module, movement/kinematics/body_axis.py, which provides:

  • ValidateAPConfig: configuration dataclass for all tunable parameters

  • FrameSelection: dataclass bundling frame indices and segment assignments

  • APNodePairReport: dataclass with detailed AP pair-evaluation results

  • validate_ap(): main validation function for a single individual

  • run_ap_validation(): multi-individual validation entry point

  • The validation is called by compute_polarization() as a side-channel diagnostic when validate_ap=True and body_axis_keypoints is provided.

  • The validation results do not affect the polarization computation itself, but are stored in polarization.attrs["ap_validation_result"] for the user to inspect.

  • Configuration parameters for the various thresholds can be supplied by the user via ap_validation_config.


The pipeline in validate_ap() works through these stages:

1. Tiered validity
  • Frames are classified as:
    • tier-1: ≥ min_valid_frac of keypoints present and ≥ 2 total
    • tier-2: all keypoints present
  • This creates a quality hierarchy:
    • tier-1 is used for motion segmentation (tolerates minor keypoint dropouts)
    • tier-2 is required for skeleton construction and PCA (demands complete observations)
2. Bounding-box centroid computation
  • Arithmetic mean of keypoints:
    • density-biased if keypoints cluster on one body region
  • Alternative:
    • use the midpoint of the axis-aligned bounding box
    • invariant to annotation-density asymmetry
  • A centroid-discrepancy diagnostic computes the normalized distance between the bbox centroid and the arithmetic centroid across tier-1 frames:
    • normalized distance = distance / bbox diagonal
    • reports median / mean / max
    • if median discrepancy > 5%, a warning is issued indicating likely asymmetric annotation density
    • this validates the bbox-centroid choice for that dataset
3. High-motion segment detection
  • Frame-to-frame centroid velocities are computed:
    • valid only when both adjacent frames are tier-1 valid
  • Sliding windows of window_len speed samples, advanced by stride samples, compute median speeds:
    • a window is accepted only if every speed sample within it is valid (non-NaN)
    • high-motion: windows whose median speed meets or exceeds the pct_thresh percentile of all valid-window medians
  • Consecutive qualifying windows form "runs" that must meet a minimum length (min_run_len)
  • Runs are converted to frame ranges and merged if overlapping or abutting
  • This focuses analysis on frames where the animal is actually moving (and thus has informative velocity)
4. Tier-2 filtering on segments
  • Selected segment frames are further filtered to retain only tier-2-valid frames (all keypoints present)
  • A warning is issued if retention falls below 30%
5. Centroid-centered skeleton construction
  • Within the selected high-motion, tier-2 frames, each skeleton is centered on its per-frame bbox centroid:
    • this is the same centroid type used for velocity computation
    • this removes translational variation to yield a shape-only representation
6. Postural clustering
  • Pairwise RMSDs between all centered skeletons are computed and partitioned into:
    • within-segment groups
    • between-segment groups
  • If the between/within variance ratio exceeds postural_var_ratio_thresh and at least 6 frames are available:
    • k-medoids clustering (with silhouette-based model selection) partitions frames into postural clusters
      • clustering is accepted only if the best silhouette score exceeds 0.2
      • otherwise, the pipeline falls back to a global average
      • the primary cluster is the largest by frame count
  • This handles cases where an animal adopts distinct postures (e.g., rearing vs. walking)
  • This ensures the body model comes from a single coherent posture
7. PCA on the average skeleton
  • SVD is performed on the valid (non-NaN) rows of the primary cluster's average centered skeleton
  • This yields:
    • PC1: the longitudinal body axis
    • PC2: the lateral axis
  • A geometric sign convention is applied post-SVD:
    • PC1 is flipped so that PC1[1] >= 0 (y-component non-negative)
    • PC2 is flipped so that PC2[0] >= 0 (x-component non-negative)
  • This ensures that axis orientation is:
    • reproducible across runs
    • decoupled from the anatomical anterior/posterior assignment (determined separately in the next step)
8. Anterior direction inference via velocity voting
  • Centroid velocities are recomputed using only adjacent consecutive frames within the same segment and the same cluster:
    • this prevents spanning gaps or mixing postures
  • These velocity vectors are projected onto PC1:
    • if more projections are positive than negative (strict majority):
      • anterior = +PC1
    • ties default to −PC1
  • The vote margin M = |n₊ − n₋| / (n₊ + n₋) quantifies confidence:
    • 0 = split
    • 1 = unanimous
  • Separately, circular statistics on velocity angles yield the resultant length R = √(C² + S²):
    • C = mean(cos θ)
    • S = mean(sin θ)
    • this measures directional concentration (0 = omnidirectional, 1 = unidirectional)
  • The product R×M is used as a composite quality score in compute_polarization():
    • each individual's R×M is determined solely by its own motion and body shape
    • R×M is independent of the input keypoint pair
    • the best individual is selected by max R×M
  • If the vote margin falls below confidence_floor:
    • the pipeline logs a warning that the anterior assignment is unreliable
  • If multiple clusters exist:
    • inter-cluster agreement on anterior polarity is reported
9. Input AP Node-Pair Filter Cascade
  • Given a candidate keypoint pair (e.g., tail_base → nose), it evaluates quality through:

    • Step I - Lateral alignment filter

      • Computes a combined score for each keypoint:
        • effective_lateral = lateral_offset_norm + lateral_var_weight × lateral_std_norm + longitudinal_var_weight × longitudinal_std_norm.
      • This penalizes keypoints that:
        • are far from the body axis
        • swing side-to-side over time
        • move along the AP axis
      • Keypoints with an effective score above lateral_thresh_pct (default: 50th percentile) are eliminated
      • The adaptive threshold retains roughly half the keypoints
      • This prefers keypoints closest to the body axis and most stable over time
      • Degenerate cases:
        • (a) If all nodes are equally offset (max == min), all normalized offsets are set to 0 and all nodes pass
        • (b) If all nodes are far from the axis but have spread, the nearest still scores 0 and passes
    • Step II - Opposite-sides constraint

      • Surviving keypoints are checked against the AP midpoint:
        • centroid = mean of PC1 projections among valid keypoints
      • Pairs are valid only if their two nodes lie on opposite sides of this midpoint:
        • the product of their signed distances from the midpoint is negative
      • Pairs on the same side cannot span the body axis
    • Step III - Distal/proximal classification

      • Each surviving pair's nodes are classified by their normalized distance from the midpoint:
        • (|pc1_coord − midpoint| / max distance among valid keypoints).
      • A pair is "distal" if:
        • both nodes have normalized midpoint distance above edge_thresh_pct (default: 70th percentile)
      • Otherwise, it is "proximal"
      • The high-percentile threshold preferentially selects body-core extremities (head/tail) over limbs
      • Degenerate case:
        • if all valid nodes are near the midpoint, the most extreme still scores 1.0 and passes
    • Loss diagnostics

      • High Step 1 loss = few axial nodes
      • High Step 2 loss = midpoint poorly separates candidates
      • Low distal fraction = annotation lacks longitudinal spread
10. Suggested Pair
  • The filter cascade identifies a single suggested AP pair using variance-weighted scoring
  • Each candidate pair's AP separation is weighted by the average stability of its two nodes:
    • weighted_sep = separation × (1 − avg_lateral_std),
    • where lateral_std is the normalized standard deviation of each node's lateral offset over time
  • This penalizes high-variance extremity keypoints (e.g., leg tips) in favor of stable body-core keypoints (e.g., thorax, abdomen)
  • If any distal pairs exist, the one with maximum weighted separation is selected (type = "distal")
  • Otherwise, the overall maximum-weighted-separation pair is selected (type = "proximal")
  • The suggested pair is ordered by order_pair_by_ab() so that:
    • element 0 is posterior (lower AP coordinate)
    • element 1 is anterior (higher AP coordinate)
    • this matches the body_axis_keypoints=(from_node, to_node) convention
  • The ordered indices are stored in max_separation_distal_nodes or max_separation_nodes on the APNodePairReport
  • The order check (input_pair_order_matches_inference) compares the input pair's AP coordinates:
    • True if from_node's AP coordinate < to_node's AP coordinate (i.e., from_node is more posterior)
11. Mutually Exclusive Scenarios
  • Classifies the outcome (accept / warn) based on whether the input pair survived all filters, is distal, has maximum separation, etc.
  • See flowchart below
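The stage-8 voting arithmetic above can be sketched in a few lines of NumPy. This is a minimal illustration with made-up velocities, not the pipeline's implementation (the real code restricts velocity samples to same-segment, same-cluster frame pairs):

```python
import numpy as np

# toy frame-to-frame centroid velocities, mostly pointing toward +x
velocities = np.array([[1.0, 0.1], [0.9, -0.2], [1.1, 0.0], [-0.3, 0.1], [1.2, 0.2]])
pc1 = np.array([1.0, 0.0])  # longitudinal axis from SVD, sign-fixed so PC1[1] >= 0

# project velocities onto PC1 and take a strict-majority vote on the anterior direction
projs = velocities @ pc1
n_pos, n_neg = int(np.sum(projs > 0)), int(np.sum(projs < 0))
anterior_sign = 1 if n_pos > n_neg else -1          # ties default to -PC1
vote_margin = abs(n_pos - n_neg) / (n_pos + n_neg)  # M: 0 = split, 1 = unanimous

# circular statistics on velocity angles: resultant length R
theta = np.arctan2(velocities[:, 1], velocities[:, 0])
C, S = np.mean(np.cos(theta)), np.mean(np.sin(theta))
R = np.hypot(C, S)  # 0 = omnidirectional, 1 = unidirectional

quality = R * vote_margin  # composite R×M score used to rank individuals
```

With these toy velocities the vote is 4 to 1, so anterior_sign is +1 and M = 0.6.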

Return: xarray Attribute `ap_validation_result`

When validate_ap=True and body_axis_keypoints is provided, compute_polarization() stores results in polarization.attrs["ap_validation_result"]:

```
{
    "all_results": [<per-individual result dicts>],
    "best_idx": int  # index into all_results (highest R×M score)
}
```

Per-Individual Result Dict Fields

| Field | Type | Description |
| --- | --- | --- |
| success | bool | Whether the pipeline completed successfully |
| anterior_sign | int | Inferred anterior direction (+1 or -1 relative to PC1) |
| vote_margin | float | Confidence in anterior assignment (0-1) |
| resultant_length | float | Directional concentration of velocities (0-1) |
| circ_mean_dir | float | Circular mean direction angle (radians; present only on success) |
| num_selected_frames | int | Tier-2 frames used for inference |
| num_clusters | int | Number of postural clusters (1 if no clustering) |
| primary_cluster | int | Index of the primary (largest) cluster |
| PC1 | ndarray | First principal component vector (2,) |
| PC2 | ndarray | Second principal component vector (2,) |
| avg_skeleton | ndarray | Average centered skeleton of the primary cluster (n_keypoints, 2) |
| vel_projs_pc1 | ndarray | Velocity projections onto PC1 (present only on success) |
| lateral_std | ndarray | Per-keypoint std of lateral (PC2) position (present only on success) |
| longitudinal_std | ndarray | Per-keypoint std of longitudinal (PC1) position (present only on success) |
| pair_report | dataclass | APNodePairReport with detailed AP pair-evaluation results |
| log_lines | list[str] | Captured diagnostic output (always populated; not printed to stdout when called via compute_polarization(), which hardcodes verbose=False) |
| error_msg | str | Error message if the pipeline failed (empty string on success) |
| individual | Hashable | Individual name (added by run_ap_validation()) |

The pair_report field contains scenario (1-13) and outcome ("accept" / "warn") from the flowchart below.


The collective.py API was recently redesigned; see the planned API revision in a newer comment.

Usage (will be deprecated):
```python
from movement.io import load_dataset
from movement.kinematics.collective import compute_polarization

# Load tracking data (must have a 'keypoints' dimension)
ds = load_dataset("tracking.slp", source_software="SLEAP", fps=30)

# Basic: compute body-axis polarization with AP validation
polarization = compute_polarization(
    ds.position,
    body_axis_keypoints=("tail_base", "nose"),
    validate_ap=True,
)

# Validation results are stored in the output's attrs
ap = polarization.attrs["ap_validation_result"]
best = ap["all_results"][ap["best_idx"]]

# Check the inferred anterior direction and confidence
print(f"Anterior sign: {best['anterior_sign']}")   # +1 or -1 relative to PC1
print(f"Vote margin M: {best['vote_margin']:.3f}")  # 0 = split, 1 = unanimous
print(f"Resultant length R: {best['resultant_length']:.3f}")  # directional concentration

# Inspect the pair evaluation
pr = best["pair_report"]
print(f"Scenario: {pr.scenario} ({pr.outcome})")  # e.g. "5 (accept)"
print(f"Input pair order matches inference: {pr.input_pair_order_matches_inference}")

# Check the suggested pair (pipeline-verified posterior → anterior)
if len(pr.max_separation_distal_nodes) > 0:
    suggested = pr.max_separation_distal_nodes  # [posterior_idx, anterior_idx]
    print(f"Suggested distal pair: {suggested}")
elif len(pr.max_separation_nodes) > 0:
    suggested = pr.max_separation_nodes
    print(f"Suggested proximal pair: {suggested}")

# Override config thresholds (any omitted key uses its default)
polarization = compute_polarization(
    ds.position,
    body_axis_keypoints=("tail_base", "nose"),
    validate_ap=True,
    ap_validation_config={
        "lateral_var_weight": 0.5,  # reduce penalty for side-to-side motion
        "confidence_floor": 0.2,    # stricter confidence warning
    },
)

# Disable validation (default behavior)
polarization = compute_polarization(
    ds.position,
    body_axis_keypoints=("tail_base", "nose"),
    validate_ap=False,  # this is the default
)

# Read the diagnostic log (always captured; not printed when called via compute_polarization())
for line in best["log_lines"]:
    print(line)

# Direct access to body_axis module for standalone validation
from movement.kinematics.body_axis import validate_ap, ValidateAPConfig

# Run validation directly on a single individual with custom config
config = ValidateAPConfig(lateral_var_weight=0.5, confidence_floor=0.2)
result = validate_ap(
    ds.position.sel(individuals="mouse1"),
    from_node="tail_base",
    to_node="nose",
    config=config,
    verbose=True,  # prints diagnostic output
)
```

How has this PR been tested?


1. Unit tests in test_body_axis.py.

Comprehensive testing is pending implementation of the planned refactor.


TestValidateAPConfig (2 tests)

  • Parameter-boundary validation for the ValidateAPConfig dataclass. Tests all 12 configurable fields:

    test_invalid_config_values_raise* (23 parametrized cases)


    - Each field is tested with out-of-range values:
      - negative fractions
      - values above 1.0 for [0, 1] fields
      - zero or negative integers for count fields
      - floats where integers are required
    - All must raise `ValueError` with a message matching `"must be"`.

    test_valid_config_does_not_raise
    Constructs a ValidateAPConfig with all fields set to non-default valid values and asserts that no exception is raised.

  • The 12 fields tested are all listed in the config table below.
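The boundary checks these tests target can be illustrated with a minimal stand-in. MiniConfig below is a hypothetical two-field mimic of ValidateAPConfig, not the real class; it only demonstrates the "must be" validation pattern the tests assert on:

```python
from dataclasses import dataclass

@dataclass
class MiniConfig:
    """Hypothetical stand-in illustrating the boundary checks."""
    min_valid_frac: float = 0.6   # must be in [0, 1]
    min_run_len: int = 1          # must be a positive integer

    def __post_init__(self):
        if not (0.0 <= self.min_valid_frac <= 1.0):
            raise ValueError("min_valid_frac must be in [0, 1]")
        if not isinstance(self.min_run_len, int) or self.min_run_len < 1:
            raise ValueError("min_run_len must be a positive integer")

# out-of-range values raise ValueError with a "must be" message
for bad in ({"min_valid_frac": -0.1}, {"min_valid_frac": 1.5},
            {"min_run_len": 0}, {"min_run_len": 2.0}):
    try:
        MiniConfig(**bad)
        raise AssertionError("expected ValueError")
    except ValueError as e:
        assert "must be" in str(e)
```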


2. Empirical Validation (SANITY CHECK)

Objective distinction:

  • The grid search here optimized for filter-cascade output:

    • it maximizes how often the suggested AP pair's nodes fall within the GT subset
    • it evaluates threshold tuning
  • The later grid search optimized for core algorithm accuracy (see later comments).

Evolution note:

  • This section documents the initial self-axis validation on 5 datasets

  • Later comments describe the updated methodology and expanded datasets.


The 3-step filter-cascade thresholds and pair-scoring method were empirically optimized via two validation studies on 5 diverse multi-animal datasets (2Flies, 2Mice, 4Gerbils, 5Mice, 2Bees) with hand-curated ground-truth AP node rankings.

Note: Multi-animal datasets have now expanded to 10 total in recent analysis.

Deprecated: exhaustive grid search.

It has been replaced with an early-exit variant (motivated by Occam's razor) using 10 datasets; see newer comments below.

Analysis 2a: Grid Search over Design and Parameter Space

Objective: find the configuration that maximizes "both nodes in GT" (i.e., the suggested pair contains two ground-truth AP nodes) with correct ordering across all datasets.

Example Script | Detailed Log | Results JSON

Method: Exhaustive grid search over 616,896 unique configurations (705,024 generated; see the note below) testing several method categories:

  • Midpoint:

    • geometric center
    • centroid (mean)
  • Lateral threshold:

    • fixed: 0.3, 0.4, 0.5
    • percentile: 30, 40, 50, 60, 70
  • Edge threshold:

    • fixed: 0.2, 0.3, 0.4, 0.5
    • percentile: 30, 40, 50, 60, 70, 80
  • Normalization:

    • body_width
    • min_max
    • percentile_rank
  • Formula:

    • additive
    • multiplicative
    • RMS
  • Pair scoring:

    • max_separation
    • weighted_variance
    • weighted_both
  • Weights:

    • lateral: 0.0, 0.5, 1.0
    • longitudinal: 0.0, 0.5, 1.0
```
Generated 705024 configurations to test
Estimated evaluations: 705024 configs × 5 datasets = 3,525,120
Running with 10 parallel workers...
Grid search completed in 58.6 seconds
```

NOTE: When both variance weights are zero, the lateral distance score has no variance penalty terms, so all four ways of combining penalties (additive, multiplicative, max-based, RMS) produce identical results. This means 88,128 of the 705,024 configurations are exact duplicates, leaving 616,896 unique configurations (3,084,480 evaluations across 5 datasets). This clarifies the 705,024 figure reported from the log.
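The bookkeeping in the note can be checked with plain arithmetic on the figures quoted from the log:

```python
# figures reported in the grid-search log and the deduplication note
total_configs = 705_024      # configurations generated by the grid
duplicate_configs = 88_128   # both variance weights zero -> penalty combiner is irrelevant

unique_configs = total_configs - duplicate_configs
assert unique_configs == 616_896                  # unique configurations
assert unique_configs * 5 == 3_084_480            # unique evaluations over 5 datasets
assert total_configs * 5 == 3_525_120             # evaluations as estimated in the log
```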

For each configuration:

  • the best individual per dataset was selected via max R×M

  • then the 3-step filter cascade was applied to identify the suggested AP pair

  • results were scored by:

      1. how many datasets achieved both in GT
      2. how many achieved correct ordering

Results + Implementation

Multiple configurations achieved 5/5 datasets with both nodes in GT and correct ordering.

Tied (with many others) for the top-ranked configuration:

| Parameter | Selected Value |
| --- | --- |
| Midpoint | centroid (mean of PC1 projections) |
| Lateral threshold | 50th percentile |
| Edge threshold | 70th percentile |
| Normalization | body_width |
| Formula | additive |
| Pair scoring | weighted_variance |
| Weights | lateral=1.0, longitudinal=0.5 |

LIMITATION: this self-axis evaluation does not test whether the inferred axis generalizes to other individuals. See later comments for shared-axis analysis.

Configuration (`ValidateAPConfig`)

All configurable thresholds are collected in a single dataclass in movement.kinematics.body_axis. Users pass overrides as a dict via ap_validation_config; any omitted key uses its default. The config below represents the state the pipeline was in during same-axis GT node-pair accuracy evaluation.
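The "any omitted key uses its default" merge behavior corresponds to `dataclasses.replace` semantics. A sketch with a hypothetical two-field stand-in (the real ValidateAPConfig has twelve fields):

```python
from dataclasses import dataclass, replace

@dataclass(frozen=True)
class MiniAPConfig:  # hypothetical stand-in for ValidateAPConfig
    lateral_thresh_pct: float = 50.0
    edge_thresh_pct: float = 70.0

defaults = MiniAPConfig()
overrides = {"edge_thresh_pct": 80.0}    # user-supplied dict, as in ap_validation_config
config = replace(defaults, **overrides)  # omitted keys keep their defaults
assert config.lateral_thresh_pct == 50.0 and config.edge_thresh_pct == 80.0
```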

| Parameter | Default | Stage/Step | Description |
| --- | --- | --- | --- |
| min_valid_frac | 0.6 | 1 | Minimum fraction of keypoints present for a frame to qualify as tier-1 valid. Must be in [0, 1]. |
| window_len | 50 | 3 | Number of speed samples per sliding window for motion detection. |
| stride | 5 | 3 | Step size (in speed samples) between consecutive sliding-window start positions. |
| pct_thresh | 85.0 | 3 | Percentile of valid-window median speeds above which a window is classified as high-motion. Must be in [0, 100]. |
| min_run_len | 1 | 3 | Minimum number of consecutive qualifying windows required to form a valid run. |
| postural_var_ratio_thresh | 2.0 | 6 | Between-segment / within-segment RMSD variance ratio above which postural clustering is triggered. Must be positive. |
| max_clusters | 4 | 6 | Upper bound on the number of clusters evaluated during k-medoids (actual upper bound is min(max_clusters, n//2)). |
| confidence_floor | 0.1 | 8 | Vote margin below which the anterior inference is flagged as unreliable. Must be in [0, 1]. |
| lateral_thresh_pct | 50.0 | 9-I | Percentile threshold for Step 1 lateral-alignment filtering. Keypoints with an effective lateral score above this percentile are eliminated. Must be in [0, 100]. |
| edge_thresh_pct | 70.0 | 9-III | Percentile threshold for Step 3 distal/proximal classification. Pairs for which both nodes have normalized midpoint distance above this percentile are classified as "distal". Must be in [0, 100]. |
| lateral_var_weight | 1.0 | 9-I | Weight for the lateral (PC2) position-variance penalty in the combined filtering score. Higher values penalize keypoints with more side-to-side motion. Must be non-negative. |
| longitudinal_var_weight | 0.5 | 9-I | Weight for the longitudinal (PC1) position-variance penalty in the combined filtering score. Higher values penalize keypoints with more AP motion. Must be non-negative. |
Per-Dataset Filter Cascade Results (5 datasets, self-axis)

Per-dataset filter-cascade results with the optimal (top-ranked) configuration applied.

These results use self-axis evaluation on best individuals only.

| Dataset | Step 1: Lateral Filter | Step 2: Opposite-Sides | Survivors | Distal | Suggested Pair | Type | Status |
| --- | --- | --- | --- | --- | --- | --- | --- |
| 2Flies | 7/13 nodes pass | 12/21 pairs pass | 12 | 1 | [2 → 0] | distal | Both in GT, correct |
| 2Mice | 3/5 nodes pass | 2/3 pairs pass | 2 | 0 | [3 → 0] | proximal | Both in GT, correct |
| 4Gerbils | 7/14 nodes pass | 10/21 pairs pass | 10 | 0 | [9 → 5] | proximal | Both in GT, correct |
| 5Mice | 6/11 nodes pass | 8/15 pairs pass | 8 | 0 | [6 → 1] | proximal | Both in GT, correct |
| 2Bees | 11/21 nodes pass | 30/55 pairs pass | 30 | 3 | [2 → 1] | distal | Both in GT, correct |

Note:

  • Step 3 classifies surviving pairs as distal or proximal; it does not eliminate any
  • When distal pairs exist, the max-weighted-separation distal pair is suggested
  • Otherwise, the max-weighted-separation proximal pair is suggested
  • The bottom-right panel of the cross-dataset figure visualizes this with stacked bars (distal/proximal)
  • The ★ marker shows which segment was selected

AP Validation Pipeline Overview

  • Top row: average skeletons (best individual by R×M within each dataset) with PC1 axes, GT nodes marked, and suggested pairs labeled

  • Bottom-left: GT-node coverage per dataset

  • Bottom-right: the filter-cascade progression:

    • Steps 1-2 filter pairs

    • Step 3 classifies survivors as:

      • distal (darker, ×× hatch)
      • proximal (lighter, .. hatch)
    • the ★ marker indicates which segment the suggested pair was selected from

    • all 5 datasets show correct AP pair identification (SANITY CHECK)

NOTE: This is a self-axis evaluation on best individuals only.

  • For each of the 5 datasets:

    • the individual with max R×M was selected
    • the filter cascade was applied to that individual's own PCA-derived axis
  • The grid search tested 616,896 configurations, sorted by a 3-tier objective:

      1. count of datasets where both suggested nodes are GT nodes (both_in_gt)
      2. count where at least one is GT (one_in_gt)
      3. count where GT ordering is correct (order_correct)
  • This does not test cross-individual generalization (shared-axis evaluation)

  • Refer to later comments for updated methodology

Example Script | Detailed Log


Analysis 2b: Metric Evaluation for “Best” Individual Selection

Validate that R×M is the best metric for selecting the “reference individual” whose AP ordering others should align with.

Note: All five metrics are computed from each individual's own data:

  • R×M: from that individual's velocity projections
  • PC1 variance ratio: from that individual's skeleton SVD
  • Mean inverse lateral variance: from that individual's keypoint stability
  • Agreement score: from comparing orderings where each individual uses their own axis
  • Skeleton completeness: from that individual's data quality

So the comparison answers:

  • “Which individual-level property best predicts that individual’s own axis will correctly order GT nodes?”

But it cannot answer:

  • “Which property best predicts that individual’s axis will correctly order GT nodes on other individuals’ skeletons?”

Why R×M succeeds (self-axis selection):

  • High R×M → strong unidirectional locomotion
  • → reliable velocity voting
  • → correct anterior sign
  • → correct axis for that individual

What high R×M does not guarantee:

  • that the individual’s body shape is representative
  • that their axis orientation transfers to differently shaped individuals
  • that projecting other skeletons onto this axis preserves ordering

R×M measures inference confidence, not axis universality

  • This analysis validates R×M as the best predictor of self-consistency, not generalizability
  • These are fundamentally different questions requiring distinct evaluation paradigms
  • Framework for generalizability is work in progress.

Example Script | Detailed Log

Method:

  • For each of 5 metrics:

    • select the individual with the highest score per dataset
    • check that individual's ground-truth accuracy (% of AP node pairs correctly ordered vs. hand-curated GT)

Metrics tested:

  • Resultant Length × Vote Margin: composite locomotion-quality score (aka R×M)

  • PC1 variance ratio: ratio of PC1 to PC2 singular values from SVD on the average skeleton

  • Mean inverse lateral variance: average of 1/σ for each keypoint's lateral offset over time

    • rewards stable body-core keypoints
  • Agreement score: fraction of other individuals whose GT-node ordering (projected onto their own AP axis) matches this individual's ordering

  • Skeleton completeness: fraction of keypoints valid (non-NaN) in the average skeleton

| Metric | 100% Accuracy | Mean Accuracy | Per-Dataset |
| --- | --- | --- | --- |
| R×M | 5/5 | 100.0% | 2Flies:✓ 2Mice:✓ 4Gerbils:✓ 5Mice:✓ 2Bees:✓ |
| mean_inv_lateral_var | 5/5 | 100.0% | 2Flies:✓ 2Mice:✓ 4Gerbils:✓ 5Mice:✓ 2Bees:✓ |
| agreement_score | 4/5 | 80.0% | 2Flies:✓ 2Mice:✓ 4Gerbils:✓ 5Mice:✓ 2Bees:0% |
| skeleton_completeness | 4/5 | 80.0% | 2Flies:✓ 2Mice:✓ 4Gerbils:✓ 5Mice:✓ 2Bees:0% |
| pc1_variance_ratio | 3/5 | 74.7% | 2Flies:✓ 2Mice:✓ 4Gerbils:73% 5Mice:✓ 2Bees:0% |

Results:

  • R×M and mean_inv_lateral_var both achieve perfect reference selection (5/5)
  • R×M is preferred because it directly measures locomotion quality (the physical basis for AP inference), rather than an indirect proxy
  • R×M requires no additional computation beyond what is already performed for anterior-direction inference
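Reference selection by max R×M then reduces to an argmax over the stored result dicts. A sketch with toy numbers (not the actual dataset values), using the field names from the per-individual result table above:

```python
# toy per-individual results mimicking entries of ap_validation_result["all_results"]
all_results = [
    {"individual": "track_0", "resultant_length": 0.10, "vote_margin": 0.04},
    {"individual": "track_1", "resultant_length": 0.80, "vote_margin": 0.90},
]

# composite quality score R×M per individual; the reference is the argmax
scores = [r["resultant_length"] * r["vote_margin"] for r in all_results]
best_idx = max(range(len(all_results)), key=scores.__getitem__)
assert all_results[best_idx]["individual"] == "track_1"
```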

CAVEAT: "perfect reference selection" here means 100% self-axis GT accuracy.

Detailed Per-Dataset Breakdown (Reference Selection)
```
4Gerbils (4 individuals):
  Individual      | R×M    | PC1 Var | InvLat  | Agree  | Compl  | GT Acc
  ---------------------------------------------------------------------------
  female          | 0.004  | 3.59    | 0.02    | 0.33   | 1.00   | 100.0%
  pup unshaved    | 0.245  | 4.07    | 0.05    | 0.33   | 1.00   | 100.0%  ← R×M selects
  male            | 0.016  | 6.08    | 0.02    | 0.00   | 1.00   |  73.3%  ← PC1 var would select (wrong)
  pup shaved      | 0.018  | 2.79    | 0.04    | 0.00   | 1.00   |  73.3%

5Mice (5 individuals):
  Individual      | R×M    | PC1 Var | InvLat  | Agree  | Compl  | GT Acc
  ---------------------------------------------------------------------------
  track_0         | 0.843  | 5.48    | 0.06    | 1.00   | 1.00   | 100.0%  ← R×M selects
  track_1         | 0.722  | 4.67    | 0.04    | 1.00   | 1.00   | 100.0%
  track_2         | 0.079  | 3.47    | 0.06    | 1.00   | 1.00   | 100.0%
  track_3         | 0.366  | 5.33    | 0.03    | 1.00   | 1.00   | 100.0%
  track_4         | 0.526  | 4.18    | 0.03    | 1.00   | 1.00   | 100.0%

2Bees (2 individuals):
  Individual      | R×M    | PC1 Var | InvLat  | Agree  | Compl  | GT Acc
  ---------------------------------------------------------------------------
  track_1         | 0.206  | 1.60    | 0.03    | 0.00   | 1.00   | 100.0%  ← R×M selects
  track_0         | 0.004  | 2.12    | 0.02    | 0.00   | 1.00   |   0.0%  ← All others would select (wrong)
```

The 2Bees case is particularly instructive:

  • track_0 has a higher PC1 variance ratio, higher skeleton completeness, and an equal agreement score, but 0% GT accuracy
  • Only R×M (and mean_inv_lateral_var) correctly identify track_1 as the trustworthy reference

This validates R×M for self-axis selection; shared-axis generalization is explored separately in later comments.

2Bees AP Validation

R×M correctly selects track_1 (100% GT accuracy) over track_0 (0% GT accuracy), despite track_0 having a higher PC1 variance ratio and higher skeleton completeness.

Other Datasets

2Flies (track_0):

2Flies AP Validation

2Mice (track_0):

2Mice AP Validation

4Gerbils (pup_unshaved):

4Gerbils AP Validation

5Mice (track_0):

5Mice AP Validation


Flowchart: Input AP Node-Pair Filter Cascade
  • Survivors:

    • pairs that passed both Step I (lateral alignment) and Step II (opposite-sides constraint)
  • Distal pair:

    • a surviving pair for which both nodes have normalized midpoint distance above the edge_thresh_pct percentile
  • Proximal pair:

    • a surviving pair for which at least one node has normalized midpoint distance below the edge_thresh_pct percentile
  • Max-sep overall:

    • the surviving pair with the largest variance-weighted AP separation among all survivors (distal or proximal)
  • Max-sep distal:

    • the surviving pair with the largest variance-weighted AP separation among distal survivors only
  • Input pair rank:

    • the input pair's rank by variance-weighted separation among all survivors
    • rank 1 = largest weighted separation
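As a concrete illustration of Step I, the combined score and adaptive percentile threshold described above can be sketched as follows (the per-keypoint arrays are toy values, and the real code's normalization may differ):

```python
import numpy as np

# toy per-keypoint statistics, already normalized to [0, 1]
lateral_offset_norm   = np.array([0.05, 0.10, 0.60, 0.90, 0.20])  # distance from body axis
lateral_std_norm      = np.array([0.10, 0.05, 0.50, 0.70, 0.15])  # side-to-side swing over time
longitudinal_std_norm = np.array([0.05, 0.10, 0.30, 0.60, 0.10])  # motion along the AP axis

lateral_var_weight, longitudinal_var_weight = 1.0, 0.5  # defaults from ValidateAPConfig
effective_lateral = (lateral_offset_norm
                     + lateral_var_weight * lateral_std_norm
                     + longitudinal_var_weight * longitudinal_std_norm)

# adaptive threshold: keep keypoints at or below the 50th percentile (roughly half)
threshold = np.percentile(effective_lateral, 50.0)
survivors = np.flatnonzero(effective_lateral <= threshold)
```

Here the three stable, near-axis keypoints (indices 0, 1, 4) survive and the two high-scoring ones are eliminated.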

AP Node-Pair Filter Cascade Flowchart

STEP I: Lateral Alignment Filter
────────────────────────────────
                [All valid keypoints]
                         |
                         v
          effective_lateral_score <= lateral_thresh_pct?
                       /   \
                     Yes    No --> [Eliminated]
                      |
                      v
               [Candidate nodes]
                      |
                      v
        >= 2 candidates? --No--> [FAIL: Step I]
                      |
                     Yes
                      |
                      v
STEP II: Opposite-Sides Constraint
─────────────────────────────────
          pair on opposite sides of centroid (mean PC1)?
                       /   \
                     Yes    No --> [FAIL: Step II]
                      |
                      v
             [Surviving pairs]       <-- pairs that passed Steps I + II
                      |
                      v
STEP III: Distal/Proximal Classification
───────────────────────────────────────
     both nodes' midline_dist_norm >= edge_thresh_pct?
                       /   \
                     Yes    No
                      |      |
                      v      v
                [Distal] [Proximal]
                      \    /
                       \  /
                        \/
                        |
                        v
SUGGESTED PAIR SELECTION (variance-weighted)
────────────────────────────────────────────
     Any distal pairs among survivors?
            /    \
          Yes     No
           |       |
           v       v
     Max weighted-sep    Max weighted-sep
     distal pair         overall pair
                        |
                        v
SCENARIO ASSIGNMENT (13 mutually exclusive outcomes)
────────────────────────────────────────────────────

Single pair survived Steps I–II?
|
+--Yes--> Input pair == the survivor?
|         |
|         +--Yes--> Survivor is distal?
|         |         |
|         |         +--Yes--> #1 ACCEPT: input pair confirmed (distal)
|         |         +--No---> #2 WARN: input pair is proximal
|         |
|         +--No---> Survivor is distal?
|                   |
|                   +--Yes--> #3 WARN: input pair eliminated, suggest survivor
|                   +--No---> #4 WARN: input pair eliminated, only option is proximal
|
+--No (multiple pairs survived)
          |
          +--> Input pair among survivors?
               |
               +--Yes--> Input pair is distal?
               |         |
               |         +--Yes--> Input pair is max-sep overall?
               |         |         |
               |         |         +--Yes-----------> #5 ACCEPT: input pair is best
               |         |         |
               |         |         +--No--> Input pair is max-sep among distal?
               |         |                  |
               |         |                  +--Yes--> #7 ACCEPT: input pair is best distal
               |         |                  +--No---> #6 WARN: better distal pair exists
               |         |
               |         +--No (input pair is proximal)
               |                   |
               |                   +--> Input pair is max-sep overall?
               |                        |
               |                        +--Yes--> Any distal survivor?
               |                        |         |
               |                        |         +--Yes--> #8 WARN: input proximal, distal alternative exists
               |                        |         +--No---> #9 WARN: input proximal, all survivors proximal
               |                        |
               |                        +--No---> Any distal survivor?
               |                                  |
               |                                  +--Yes--> #10 WARN: input proximal, distal alternative exists
               |                                  +--No---> #11 WARN: input proximal, all survivors proximal
               |
               +--No (input pair not among survivors)
                         |
                         +--> Any distal survivor?
                              |
                              +--Yes--> #12 WARN: input eliminated, suggest max-sep distal
                              +--No---> #13 WARN: input eliminated, suggest max-sep overall
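A minimal sketch of Steps I–III, assuming the percentile thresholds have already been resolved to absolute cutoffs (`filter_cascade` and its argument names are illustrative, not the module's API):

```python
import numpy as np
from itertools import combinations

def filter_cascade(lateral_score, ap_coords, midline_dist_norm,
                   lateral_thresh, edge_thresh):
    """Sketch of the three-step AP pair filter cascade.

    Step I keeps laterally aligned nodes, Step II keeps pairs straddling
    the candidate centroid along PC1, Step III splits the survivors into
    distal and proximal pairs.
    """
    # Step I: lateral alignment filter
    candidates = [k for k, s in enumerate(lateral_score) if s <= lateral_thresh]
    if len(candidates) < 2:
        return None  # FAIL: Step I

    # Step II: opposite-sides constraint about the candidate mean PC1
    mid = np.mean([ap_coords[k] for k in candidates])
    survivors = [(i, j) for i, j in combinations(candidates, 2)
                 if (ap_coords[i] - mid) * (ap_coords[j] - mid) < 0]
    if not survivors:
        return None  # FAIL: Step II

    # Step III: distal iff both nodes sit far enough from the midpoint
    distal = [p for p in survivors
              if all(midline_dist_norm[k] >= edge_thresh for k in p)]
    proximal = [p for p in survivors if p not in distal]
    return candidates, distal, proximal
```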

References

Is this a breaking change?

No.

Does this PR require an update to the documentation?

No. API docs are auto-generated from docstrings.

Checklist

  • Code tested locally
  • Tests added for new functionality
  • Formatted with pre-commit

Future Refactoring

The body_axis.py module (~2,900 lines) is intentionally monolithic in this PR to simplify review and iteration. Once the API stabilizes, general-purpose functionality could be extracted into existing or new utility modules.

For example:

movement/
├── kinematics/
│   ├── body_axis.py          # Reduced: AP-specific logic only
│   ├── collective.py
│   └── ...
├── utils/
│   ├── vector.py             # + circular_mean, resultant_length (from body_axis)
│   ├── clustering.py         # NEW: kmedoids, silhouette_score (from body_axis)
│   ├── temporal.py           # NEW: detect_runs, merge_segments (from body_axis)
│   └── ...

@khan-u khan-u force-pushed the body-axis-ap-inference branch 2 times, most recently from 5cd79d6 to 01d16a8 Compare April 2, 2026 10:05
@khan-u khan-u marked this pull request as draft April 2, 2026 11:05
@khan-u khan-u changed the title feat(collective): add prior-free body-axis inference WIP: feat(collective): add prior-free body-axis inference for compute_polarization Apr 2, 2026
@khan-u khan-u changed the title WIP: feat(collective): add prior-free body-axis inference for compute_polarization feat(collective): add prior-free body-axis inference for compute_polarization Apr 2, 2026
@khan-u khan-u marked this pull request as ready for review April 2, 2026 13:49
@khan-u khan-u force-pushed the body-axis-ap-inference branch from 01d16a8 to 51866c9 Compare April 4, 2026 05:21
@khan-u khan-u force-pushed the body-axis-ap-inference branch from cbc1c25 to 1bd1618 Compare April 4, 2026 08:05
@khan-u khan-u changed the title feat(collective): add prior-free body-axis inference for compute_polarization feat(body_axis): add prior-free A-P body-axis inference Apr 4, 2026
@khan-u khan-u changed the title feat(body_axis): add prior-free A-P body-axis inference feat(body_axis): add prior-free AP body-axis inference Apr 4, 2026
@khan-u khan-u closed this Apr 5, 2026
@khan-u khan-u reopened this Apr 5, 2026
@khan-u khan-u force-pushed the body-axis-ap-inference branch 5 times, most recently from c78dce3 to 615666c Compare April 5, 2026 09:48
@khan-u khan-u force-pushed the body-axis-ap-inference branch from 358e817 to b1df3b9 Compare April 5, 2026 09:55
@khan-u
Author

khan-u commented Apr 7, 2026

Self-axis grid-search workflow for tuning the default thresholds of the 3-step filter cascade

The workflow assumes only that keypoints exhibit lateral alignment about the body axis.

Datasets

https://dreem.sleap.ai/0.5.1/datasets
https://legacy.sleap.ai/datasets.html#datasets

Species   Datasets   Individuals
Fly       5          2, 4, 4, 8, 8
Mouse     3          2, 2, 5
Gerbil    1          4
Bee       1          2
Total     10         41
Self-Axis Workflow (redesigned as shared-axis in next comment)
├── Phase 0: Precompute per-dataset PCA data 
│   │
│   ├── discover_datasets()
│   │   └── Scan SLP_DIR for .slp files with matching GROUND_TRUTH entries
│   │       └── 10 datasets: 5× fly, 3× mouse, 1× gerbil, 1× bee
│   │
│   ├── For each dataset:
│   │   │
│   │   ├── compute_best_individual_with_pca()
│   │   │   ├── Run validate_ap() per individual using GT pair
│   │   │   ├── Compute R×M per individual
│   │   │   ├── Select individual with max R×M
│   │   │   └── Cache: avg_skeleton, PC1, PC2, anterior_sign,
│   │   │             lateral_std, longitudinal_std
│   │   │
│   │   └── precompute_from_cached_validation()
│   │       ├── Project avg_skeleton onto PC1 → pc1_coords
│   │       ├── Project avg_skeleton onto PC2 → lateral_offsets (absolute)
│   │       ├── Apply anterior_sign → ap_coords (signed PC1)
│   │       ├── Normalize lateral_std → lat_std_norm (/ max)
│   │       ├── Normalize longitudinal_std → long_std_norm (/ max)
│   │       ├── Compute valid_mask (non-NaN keypoints)
│   │       └── Returns: PrecomputedData (unchanged)
│   │
│   └── Result: dict[dataset_name → PrecomputedData]
│
│
├── Phase 1: Generate config grid (degenerate elimination)
│   │
│   ├── generate_param_grid()
│   │
│   ├── For each (lateral_var_weight, longitudinal_var_weight) combo:
│   │   │
│   │   ├── n_active_weights = 0 (both weights == 0):
│   │   │   └── score_formula is irrelevant (all reduce to d)
│   │   │       └── Fix to single formula (additive)
│   │   │       └── Effective grid: 6 × 24 × 17 × 3 × 1 × 4 = 29,376
│   │   │
│   │   └── n_active_weights > 0 (at least one weight > 0):
│   │       └── score_formula matters → enumerate all 4
│   │           └── Effective grid per weight combo: 6 × 24 × 17 × 3 × 4 × 4 = 117,504
│   │
│   ├── Total configs:
│   │   ├── n_active_weights = 0:   29,376   (1 weight combo × 1 formula)
│   │   ├── n_active_weights = 1:  235,008   (2 weight combos × 4 formulas)
│   │   ├── n_active_weights = 2:  352,512   (3 weight combos × 4 formulas)
│   │   └── Grand total:   616,896   
│   │
│   └── Partition by n_active_weights into three buckets
│
│
├── Phase 2: Search (priority cascade)
│   │
│   ├── n_active_weights ← 0
│   │
│   └── LOOP:
│       │
│       ├── Get all configs with this n_active_weights
│       │
│       ├── Compute gt_ordering_accuracy for each config (parallel, N workers)
│       │   │
│       │   └── Per config: compute_gt_ordering_accuracy(config, precomputed_data)
│       │       │
│       │       ├── For each dataset:
│       │       │   │
│       │       │   ├── run_validation_fast(precomputed, config)
│       │       │   │   │
│       │       │   │   ├── Step A: Normalize lateral offsets 
│       │       │   │   │   └── normalize_d(lateral_offsets, valid_mask, config.norm_method)
│       │       │   │   │
│       │       │   │   ├── Step B: Compute midpoint along AP axis
│       │       │   │   │   └── compute_midpoint(pc1_coords, ..., config.midpoint_method)
│       │       │   │   │
│       │       │   │   ├── Step C: Lateral filter (Step 1 of cascade)
│       │       │   │   │   └── apply_lateral_filter_custom(...)
│       │       │   │   │       ├── Compute effective score per keypoint
│       │       │   │   │       ├── Compute effective threshold (fixed method)
│       │       │   │   │       ├── Apply lateral_filter_method at lateral_filter_param
│       │       │   │   │       └── → candidates[] (keypoint indices surviving filter)
│       │       │   │   │           └── Early exit if < 2 candidates
│       │       │   │   │
│       │       │   │   ├── Step D: Edge filter (Steps 2–3 of cascade)
│       │       │   │   │   └── apply_edge_filter_custom(...)
│       │       │   │   │       ├── Find opposite-side pairs (candidates straddling midpoint)
│       │       │   │   │       ├── Apply edge_filter_method at edge_filter_param
│       │       │   │   │       ├── Classify: distal_pairs[] vs proximal_pairs[]
│       │       │   │   │       └── → pair arrays
│       │       │   │   │           └── Early exit if no pairs survive
│       │       │   │   │
│       │       │   │   ├── Step E: Score and select best pair
│       │       │   │   │   └── select_best_pair(...)
│       │       │   │   │       ├── Prefer distal_pairs; fall back to proximal_pairs
│       │       │   │   │       ├── Score each pair via pair_scoring_method
│       │       │   │   │       └── → (posterior, anterior, pair_type, score, margin)
│       │       │   │   │
│       │       │   │   └── Returns: ValidationResult 
│       │       │   │       ├── candidates: np.ndarray       
│       │       │   │       ├── suggested_posterior, suggested_anterior
│       │       │   │       ├── pair_type, both_in_gt, order_correct
│       │       │   │       └── n_step1_candidates, n_step2_pairs, ...
│       │       │   │
│       │       │   └── Evaluate GT ordering among surviving candidates 
│       │       │       │
│       │       │       ├── gt_in_candidates = GT nodes ∩ candidates[]
│       │       │       │
│       │       │       ├── For each GT pair (i, j) where i more posterior:
│       │       │       │   ├── Both i and j in candidates?
│       │       │       │   │   ├── YES → count as checkable pair
│       │       │       │   │   │   └── correct if ap_coords[j] > ap_coords[i]
│       │       │       │   │   └── NO → pair is uncheckable
│       │       │       │   │       └── (GT node filtered out = config too aggressive)
│       │       │       │
│       │       │       └── Per dataset: (correct, checkable, total_gt_pairs)
│       │       │
│       │       ├── Aggregate by species:
│       │       │   └── Per species: sum(correct) / sum(total_gt_pairs)
│       │       │       │
│       │       │       └── NOTE: denominator is total_gt_pairs, not checkable.
│       │       │           A filtered-out GT node means the pair counts as
│       │       │           INCORRECT, not as excluded. This penalizes
│       │       │           overly aggressive configs.
│       │       │
│       │       └── gt_ordering_accuracy = mean of per-species accuracies
│       │           └── 100% requires: every GT pair in every dataset
│       │               has both nodes in candidates AND correct ordering
│       │
│       ├── Filter to configs with gt_ordering_accuracy = 100%
│       │
│       ├── None left?
│       │   ├── n_active_weights < 2 → n_active_weights++, LOOP
│       │   └── n_active_weights = 2 → FAIL (no config achieves 100%)
│       │
│       ├── Exactly 1 left?
│       │   └── EXIT winner
│       │
│       └── Multiple left → tiebreak cascade:
│           │
│           ├── Rank by param_sensitivity (keep lowest)
│           │   └── param_sensitivity = lateral_sens[method] + edge_sens[method]
│           │
│           ├── Single survivor?
│           │   └── EXIT winner
│           │
│           └── Still tied → timing phase
│               ├── For each tied config:
│               │   ├── 1 warmup run (all datasets)
│               │   └── 100 timed runs (sequential, perf_counter, ms)
│               │       └── Mean time across runs
│               └── EXIT fastest
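The config counts above follow directly from multiplying the axis sizes (6 midpoint × 24 lateral-filter × 17 edge-filter × 3 normalization × score-formula × 4 pair-scoring options, with weight-combo counts taken from the tree); a quick arithmetic check:

```python
def grid_size(n_formulas: int) -> int:
    # midpoint (6) × lateral filter (24) × edge filter (17)
    # × normalization (3) × score formula × pair scoring (4)
    return 6 * 24 * 17 * 3 * n_formulas * 4

# weight-combo counts per bucket, as listed in Phase 1 above
n_weight_combos = {0: 1, 1: 2, 2: 3}
# the four score formulas collapse to one when both weights are zero
n_formulas = {0: 1, 1: 4, 2: 4}

totals = {k: grid_size(n_formulas[k]) * n_weight_combos[k] for k in (0, 1, 2)}
```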

Handling uncheckable pairs

When a GT node is filtered out by the lateral filter, any GT pair
containing that node becomes uncheckable. 

  Option 1: Exclude from denominator (lenient)
      A config that filters aggressively gets a free pass
      on pairs it can't check. Could achieve 100% by
      filtering out all but two GT nodes.

  Option 2: Count as incorrect (strict) 
      A config must preserve ALL GT nodes through the
      lateral filter AND order them correctly.
      Penalizes overfiltering and underfiltering equally.

We choose Option 2 because our goal is to find a config whose
filter cascade is compatible with the ground-truth ordering.
A config that destroys GT information is not "correct by default."
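The strict (Option 2) metric for one individual can be sketched as follows; `gt_ordering_accuracy` is a hypothetical helper, not the grid-search script's API:

```python
def gt_ordering_accuracy(gt_order, candidates, ap_coords):
    """Strict (Option 2) GT ordering accuracy for one individual.

    gt_order lists GT node indices posterior -> anterior; any GT pair
    containing a filtered-out node counts as incorrect, not excluded.
    Returns (correct, total_gt_pairs).
    """
    pairs = [(gt_order[a], gt_order[b])
             for a in range(len(gt_order))
             for b in range(a + 1, len(gt_order))]
    correct = sum(
        1 for i, j in pairs
        if i in candidates and j in candidates and ap_coords[j] > ap_coords[i]
    )
    return correct, len(pairs)
```

Note the denominator is always the full pair count, so an aggressive lateral filter that drops a GT node loses credit for every pair touching that node.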

Result:

Log: /Users/ukhan/Desktop/ss/movement-fork/PR-assets/datasets/multi-animal/exports/grid_search/logs/find_optimal_20260406_234842.log
Started: 2026-04-06T23:48:42.169053

Discovering datasets...
Found 10 datasets

Ground truth AP node rankings (posterior to anterior):
  2Mice-A: tb -> snout
  4Gerbils: spine5 -> spine4 -> spine3 -> spine2 -> spine1 -> nose
  8Flies-A: abdomen -> thorax -> head
  8Flies-B: abdomen -> thorax -> head
  4Flies-B: abdomen -> thorax -> head
  2Mice-B: tb -> snout
  5Mice: tail_base -> neck -> nose
  2Bees: abdo -> thor -> head
  2Flies: abdomen -> thorax -> head
  4Flies-A: abdomen -> thorax -> head

Loading datasets...
Computing R×M independently for all datasets...
Computing best individuals via R×M selection...

....

Generating configs...
Total configs: 616896
  n_active_weights=0: 29376
  n_active_weights=1: 235008
  n_active_weights=2: 352512
Parallel workers: 10


[n_active_weights=0] Testing 29376 configs (parallel)
-> configs with 100% gt_ordering_accuracy: 13872
-> Rank by param_sensitivity: min=1, tied=1872
-> Still tied: timing 1872 configs (100 runs each, sequential)...
-> EXIT fastest (0.981ms)


...

Finished: 2026-04-06T23:54:45.134059
Parameter        Selected Value
Midpoint         Geometric center
Lateral filter   Z-score (param=0.5)
Edge filter      Adaptive min k (param=1)
Normalization    body width
Formula          additive
Pair scoring     max separation
Weights          lateral=0, longitudinal=0

NOTE: These thresholds are optimized for self-axis evaluation only; see the next comment for limitations regarding cross-individual (shared-axis) and cross-species/dataset application.

find_optimal_20260406_234842.log

@khan-u khan-u changed the title feat(body_axis): add prior-free AP body-axis inference feat(body_axis): add skeleton-agnostic AP inference Apr 7, 2026
@khan-u
Author

khan-u commented Apr 7, 2026

The self-axis grid search architecture optimizes thresholds for an evaluation that doesn't test generalization

  • The grid search tested 29,376 filter cascade configurations at n_active_weights=0 (variance penalties disabled) to find thresholds achieving 100% GT ordering accuracy.1
  • This evaluation tests only the best individual per dataset on their own axis.
  • It does not test whether the same thresholds correctly order GT nodes for other individuals in that dataset when projected onto the best individual's axis (the actual use case).

A separate validation script applied the grid-search-winning thresholds to all 41 individuals. The self-axis evaluation (Pass 3) confirmed 100% accuracy for best individuals (SANITY CHECK):

Pass 3: GT Ordering Accuracy (GT nodes only)
  Metric: For each individual, project GT nodes onto inferred AP axis,
          count correctly-ordered pairs.
  This measures ORDERING CORRECTNESS of the core algorithm (PCA + velocity voting).
  100% = all GT node pairs ordered correctly. Independent of filter cascade.
  ...
  GT Ordering Accuracy Summary (best individual per dataset):
    bee: 3/3 = 100.0%
    fly: 15/15 = 100.0%
    gerbil: 15/15 = 100.0%
    mouse: 5/5 = 100.0%
  Overall GT ordering accuracy (species-averaged): 100.0%

However, when projecting each individual onto their dataset's reference axis (derived from that dataset's best individual), accuracy dropped substantially:

SHARED-AXIS GT ORDERING ANALYSIS
  Each individual's skeleton projected onto best individual's AP axis.
  Accuracy = correctly ordered GT node pairs / total GT pairs.

  4Gerbils (reference: pup unshaved):
    female: 3/15 pairs = 20.0% (R×M=0.004, angle=22.2°)
    male: 4/15 pairs = 26.7% (R×M=0.016, angle=40.7°)
    pup shaved: 5/15 pairs = 33.3% (R×M=0.018, angle=57.3°)
    *pup unshaved*: 10/15 pairs = 66.7% (R×M=0.245, angle=71.6°)

  2Flies (reference: track_0):
    *track_0*: 1/3 pairs = 33.3% (R×M=0.024, angle=69.1°)
    track_1: 0/3 pairs = 0.0% (R×M=0.021, angle=62.5°)

  2Bees (reference: track_1):
    track_0: 0/3 pairs = 0.0% (R×M=0.004, angle=10.4°)
    *track_1*: 1/3 pairs = 33.3% (R×M=0.206, angle=89.6°)
SUMMARY (best individuals only)

  Per-species accuracy:
    bee: 1/3 = 33.3%
    fly: 10/15 = 66.7%
    gerbil: 10/15 = 66.7%
    mouse: 5/5 = 100.0%

  Overall GT ordering accuracy (species-averaged): 66.7%

ap_validation_20260407_002942.log

Footnotes

  1. Side note: the grid search prioritizes n_active_weights=0 first because when both variance weights are zero, all four score_formula options (additive, multiplicative, max-based, RMS) produce identical scores - the only structural degeneracy in the parameter space. This reduces the search space by 88,128 configs (12.5%). It is a computational optimization, not a principled claim about whether variance penalties are worth their added complexity.
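The degeneracy in the footnote can be illustrated with a sketch; the formula names come from the footnote, but the exact penalty algebra below is guessed for illustration only:

```python
def score(d, lat_pen, long_pen, lat_w, long_w, formula):
    """Hypothetical pair score under the four candidate formulas."""
    pl, pg = lat_w * lat_pen, long_w * long_pen
    if formula == "additive":
        return d - (pl + pg)
    if formula == "multiplicative":
        return d * (1 - pl) * (1 - pg)
    if formula == "max":
        return d - max(pl, pg)
    if formula == "rms":
        return d - (pl ** 2 + pg ** 2) ** 0.5
    raise ValueError(formula)
```

When both weights are zero every formula returns the raw separation `d`, so enumerating all four at n_active_weights=0 would only duplicate work.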

@khan-u
Author

khan-u commented Apr 7, 2026

Grid-Search for Shared-Axis Evaluations

Motivation

The filter cascade thresholds must satisfy a robustness constraint: given any individual's skeleton projected onto a shared reference axis, the cascade must preserve all GT nodes through filtering and order them correctly on that axis.

This constraint is non-trivial because individuals within a dataset exhibit several sources of variability, each with a distinct effect on shared-axis projections:

  • Geometric variability - different mean skeletal poses yield different lateral offset distributions
  • Axis misalignment - an individual's intrinsic AP axis may diverge substantially from the reference axis
  • Variance heterogeneity - per-keypoint variance structures differ, affecting score normalization

Self-axis evaluation (demonstrated above) sidesteps this variability entirely:

  • Each individual defines its own axis, so projection is onto that individual's own principal components
  • Variance statistics are computed from that individual's own data.
  • Thresholds tuned under these conditions may be too strict or too lenient when applied to other individuals.

Shared-axis evaluation (current goal) confronts this variability directly:

  • All 41 individuals across 10 datasets are projected onto their respective dataset's reference axis (the individual with highest R×M).
  • A config achieves 100% GT ordering accuracy only if the thresholds correctly handle the full spectrum of inter-individual variability present in the evaluation corpus.

Shared-Axis Workflow

├── Phase 0: Precompute per-dataset PCA data (shared-axis approach)
│   │
│   ├── discover_datasets()
│   │   └── Scan SLP_DIR for .slp files with matching GROUND_TRUTH entries
│   │       └── 10 datasets: 5× fly, 3× mouse, 1× gerbil, 1× bee
│   │
│   ├── For each dataset:
│   │   │
│   │   ├── compute_best_individual_with_pca()
│   │   │   ├── Run validate_ap() per individual using GT pair
│   │   │   ├── Compute R×M per individual
│   │   │   ├── Select individual with max R×M → REFERENCE individual
│   │   │   └── Return ALL individuals' val_results
│   │   │       └── Each val_result contains: avg_skeleton, PC1, PC2,
│   │   │           anterior_sign, lateral_std, longitudinal_std
│   │   │
│   │   ├── For EACH individual in dataset:
│   │   │   │
│   │   │   └── precompute_shared_axis_projection()
│   │   │       ├── Get individual's val_result (avg_skeleton, lateral_std, longitudinal_std)
│   │   │       ├── Project onto REFERENCE's PC1 → pc1_coords
│   │   │       ├── Project onto REFERENCE's PC2 → lateral_offsets (absolute)
│   │   │       ├── Apply REFERENCE's anterior_sign → ap_coords (signed PC1)
│   │   │       ├── Use INDIVIDUAL's lateral_std, longitudinal_std (normalized)
│   │   │       ├── Compute valid_mask (non-NaN keypoints)
│   │   │       └── Returns: PrecomputedData (with individual_name field)
│   │   │
│   │   └── Result per dataset: list[PrecomputedData]
│   │
│   └── Result: dict[dataset_name → list[PrecomputedData]]
│
├── Phase 1: Generate config grid (degenerate elimination)
│   │
│   ├── generate_param_grid()
│   │
│   ├── For each (lateral_var_weight, longitudinal_var_weight) combo:
│   │   │
│   │   ├── n_active_weights = 0 (both weights == 0):
│   │   │   └── score_formula is irrelevant (all reduce to d)
│   │   │       ├── Fix to single formula (additive)
│   │   │       └── Effective grid: 6 × 24 × 17 × 3 × 1 × 4 = 29,376
│   │   │
│   │   └── n_active_weights > 0 (at least one weight > 0):
│   │       └── score_formula matters → enumerate all 4
│   │           └── Effective grid per weight combo: 6 × 24 × 17 × 3 × 4 × 4 = 117,504
│   │
│   ├── Total configs:
│   │   ├── n_active_weights = 0:   29,376
│   │   ├── n_active_weights = 1:  235,008
│   │   ├── n_active_weights = 2:  352,512
│   │   └── Grand total:   616,896  
│   │
│   └── Partition by n_active_weights into three buckets
│
└── Phase 2: Search (priority cascade)
    │
    ├── n_active_weights ← 0
    │
    └── LOOP:
        │
        ├── Get all configs with this n_active_weights
        │
        ├── Compute gt_ordering_accuracy for each config (parallel, N workers)
        │   │
        │   └── Per config: compute_gt_ordering_accuracy(config, precomputed_data)
        │       │
        │       ├── For each dataset:
        │       │   │
        │       │   └── For each individual's PrecomputedData:
        │       │       │
        │       │       ├── run_validation_fast(precomputed, config)
        │       │       │   │
        │       │       │   ├── Step A: Normalize lateral offsets
        │       │       │   │   └── normalize_d(lateral_offsets, valid_mask, config.norm_method)
        │       │       │   │
        │       │       │   ├── Step B: Compute midpoint along AP axis
        │       │       │   │   └── compute_midpoint(pc1_coords, ..., config.midpoint_method)
        │       │       │   │
        │       │       │   ├── Step C: Lateral filter (Step 1 of cascade)
        │       │       │   │   └── apply_lateral_filter_custom(...)
        │       │       │   │       ├── Compute effective score per keypoint
        │       │       │   │       ├── Compute effective threshold (fixed method)
        │       │       │   │       ├── Apply lateral_filter_method at lateral_filter_param
        │       │       │   │       └── → candidates[] (keypoint indices surviving filter)
        │       │       │   │           └── Early exit if < 2 candidates
        │       │       │   │
        │       │       │   ├── Step D: Edge filter (Steps 2–3 of cascade)
        │       │       │   │   └── apply_edge_filter_custom(...)
        │       │       │   │       ├── Find opposite-side pairs (candidates straddling midpoint)
        │       │       │   │       ├── Apply edge_filter_method at edge_filter_param
        │       │       │   │       ├── Classify: distal_pairs[] vs proximal_pairs[]
        │       │       │   │       └── → pair arrays
        │       │       │   │           └── Early exit if no pairs survive
        │       │       │   │
        │       │       │   ├── Step E: Score and select best pair
        │       │       │   │   └── select_best_pair(...)
        │       │       │   │       ├── Prefer distal_pairs; fall back to proximal_pairs
        │       │       │   │       ├── Score each pair via pair_scoring_method
        │       │       │   │       └── → (posterior, anterior, pair_type, score, margin)
        │       │       │   │
        │       │       │   └── Returns: ValidationResult
        │       │       │       ├── candidates: np.ndarray
        │       │       │       ├── suggested_posterior, suggested_anterior
        │       │       │       ├── pair_type, both_in_gt, order_correct
        │       │       │       └── n_step1_candidates, n_step2_pairs, ...
        │       │       │
        │       │       └── Evaluate GT ordering among surviving candidates
        │       │           │
        │       │           ├── gt_in_candidates = GT nodes ∩ candidates[]
        │       │           │
        │       │           ├── For each GT pair (i, j) where i more posterior:
        │       │           │   ├── Both i and j in candidates?
        │       │           │   │   ├── YES → count as checkable pair
        │       │           │   │   │   └── correct if ap_coords[j] > ap_coords[i]
        │       │           │   │   └── NO → pair is uncheckable
        │       │           │   │       └── (GT node filtered out = config too aggressive)
        │       │           │
        │       │           └── Per individual: (correct, checkable, total_gt_pairs)
        │       │
        │       ├── Aggregate by species:
        │       │   └── Per species:
        │       │       sum(correct across all individuals) /
        │       │       sum(total_gt_pairs across all individuals)
        │       │
        │       └── gt_ordering_accuracy = mean of per-species accuracies
        │
        ├── Filter to configs with gt_ordering_accuracy = 100%
        │
        ├── None left?
        │   ├── n_active_weights < 2 → n_active_weights++, LOOP
        │   └── n_active_weights = 2 → FAIL
        │
        ├── Exactly 1 left?
        │   └── EXIT winner
        │
        └── Multiple left → tiebreak cascade:
            │
            ├── Rank by param_sensitivity (keep lowest)
            │
            ├── Single survivor?
            │   └── EXIT winner
            │
            └── Still tied → timing phase (time_config)
                ├── 1 warmup run
                └── 100 timed runs → select fastest

Handling uncheckable pairs

When a GT node is filtered out by the lateral filter, any GT pair
containing that node becomes uncheckable.

  Option 1: Exclude from denominator (lenient)
      A config that filters aggressively gets a free pass
      on pairs it can't check. Could achieve 100% by
      filtering out all but two GT nodes.

  Option 2: Count as incorrect (strict)
      A config must preserve ALL GT nodes through the
      lateral filter AND order them correctly.
      Penalizes overfiltering and underfiltering equally.

We choose Option 2 because our goal is to find a config whose
filter cascade is compatible with the ground-truth ordering.
A config that destroys GT information is not "correct by default."

NOTE: This is evaluated for EACH individual projected onto
the dataset's shared (reference) axis, not just the reference
individual itself. A config must work across all individuals.
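The shared-axis projection step (`precompute_shared_axis_projection` in the tree above) can be sketched as below; the helper is hypothetical and assumes 2-D keypoints with the reference PCs as unit vectors:

```python
import numpy as np

def project_onto_reference(avg_skeleton, ref_pc1, ref_pc2, ref_anterior_sign):
    """Project one individual's mean skeleton onto the reference axes.

    Coordinates along the reference PC1 become signed AP positions;
    absolute PC2 components become lateral offsets.
    """
    centered = avg_skeleton - np.nanmean(avg_skeleton, axis=0)
    pc1_coords = centered @ ref_pc1
    ap_coords = ref_anterior_sign * pc1_coords   # signed PC1
    lateral_offsets = np.abs(centered @ ref_pc2)  # absolute PC2
    return ap_coords, lateral_offsets
```

Per the workflow, only the axes come from the reference individual; the variance statistics used downstream remain the projected individual's own.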

TO DO:

Run the shared-axis grid search to identify configs that achieve (ideally) 100% GT ordering accuracy across all 41 individuals. Then perform the true validation: verify that the winning config correctly orders Step 1 survivors outside the GT set. Because GT nodes informed config selection, correct ordering of non-GT survivors is out-of-sample evidence that the cascade captures genuine AP structure rather than artifacts of the GT subset, and that it generalizes across individuals projected onto each dataset's reference (best-R×M) axis.

@khan-u
Author

khan-u commented Apr 7, 2026

Revised body_axis API

Revised collective.py complementing these changes

Revised API – Newer Comment

body_axis.py 
│
├── infer_body_axis(ds, individual=None)
│   │   Auto-detect (posterior, anterior) pair
│   │
│   ├── individual: str | None
│   │   ├── Default: None → best individual (max R×M)
│   │   └── Override → Reject: name not in dataset individuals
│   │
│   ├── Reject: dataset has < 2 keypoints
│   │
│   ├── Runs: full AP pipeline on selected individual
│   │   │
│   │   ├── prepare_validation_inputs(data)
│   │   │   └── → position array, keypoint names, n_frames, n_keypoints
│   │   │
│   │   ├── run_motion_segmentation(pos_np, config)
│   │   │   ├── compute_tiered_validity() → tier1_valid, tier2_valid
│   │   │   ├── compute_bbox_centroid() → centroids, discrepancy
│   │   │   ├── compute_frame_velocities() → velocities, speeds
│   │   │   └── detect_motion_segments() → segments
│   │   │       ├── compute_sliding_window_medians()
│   │   │       ├── detect_high_motion_windows()
│   │   │       ├── detect_runs()
│   │   │       ├── convert_runs_to_segments()
│   │   │       └── merge_segments()
│   │   │
│   │   ├── select_tier2_frames(segments, tier2_valid, config)
│   │   │   └── → selected_indices (100% keypoint visibility)
│   │   │
│   │   ├── decide_and_run_clustering(pos_np, selected_indices, centroids, config)
│   │   │   ├── build_centered_skeletons()
│   │   │   ├── compute_postural_variance_ratio()
│   │   │   │   └── ratio > threshold? → perform_postural_clustering()
│   │   │   │       ├── kmedoids()
│   │   │   │       └── silhouette_score()
│   │   │   └── compute_cluster_pca_and_anterior()
│   │   │       ├── SVD → PC1, PC2
│   │   │       ├── compute_cluster_velocities()
│   │   │       └── infer_anterior_from_velocities()
│   │   │       └── → PC1, PC2, anterior_sign, vote_margin, R×M, avg_skeleton
│   │   │
│   │   └── evaluate_ap_node_pair(avg_skeleton, PC1, anterior_sign)
│   │       │   3-step filter cascade → score all candidate pairs
│   │       │
│   │       ├── Step 1: compute_node_projections → apply_lateral_filter
│   │       │   └── Keep nodes close to body axis → sorted_candidate_nodes
│   │       │
│   │       ├── Step 2: find_opposite_side_pairs
│   │       │   └── Pairs on opposite sides of midpoint → valid_pairs
│   │       │
│   │       └── Step 3: classify_distal_proximal
│   │           └── → distal_pairs, proximal_pairs
│   │           └── max_separation_distal_nodes (top pair)
│   │
│   └── Returns: InferenceResult
│       ├── posterior_keypoint, anterior_keypoint
│       ├── confidence_score (R×M)
│       ├── pc1_vector: np.ndarray
│       ├── anterior_sign: int (+1 or −1)
│       └── sorted_candidate_nodes[]
│           └── Consumed by heading auto-detect path
│
│
├── validate_ap_pair(ds, posterior, anterior, individual=None)
│   │   Validate a user-provided (posterior, anterior) pair
│   │
│   ├── Reject: either name not in dataset keypoints
│   ├── Reject: posterior == anterior
│   │
│   ├── individual: str | None
│   │   ├── Default: None → best individual (max R×M)
│   │   └── Override → Reject: name not in dataset individuals
│   │
│   ├── Runs: same AP pipeline as infer_body_axis on selected individual
│   │   └── evaluate_ap_node_pair called with user's (posterior, anterior)
│   │       └── APNodePairReport:
│   │           ├── input_pair_in_candidates
│   │           ├── input_pair_opposite_sides
│   │           ├── input_pair_is_distal
│   │           └── input_pair_order_matches_inference
│   │
│   └── Returns: ValidationReport
│       ├── is_valid: bool  ← input_pair_order_matches_inference
│       └── pair_report: APNodePairReport
│
│
└── project_keypoints(ds, axis_result, individuals=None)
    │   Project all keypoints onto inferred AP axis
    │
    ├── axis_result: InferenceResult
    │   └── Reject: invalid axis_result
    │
    ├── individuals: list[str] | None
    │   ├── Default: None → all individuals in dataset
    │   └── Override → Reject: any name not in dataset
    │
    └── Returns: xr.DataArray of AP coordinates (time × keypoints × individuals)
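To make the return contract concrete, here is a stand-alone mirror of `InferenceResult` using the field names from the tree above; the dataclass and the commented call pattern are illustrative sketches of the proposed API, not the merged implementation:

```python
from dataclasses import dataclass, field
import numpy as np

@dataclass
class InferenceResult:
    """Mirror of the proposed infer_body_axis return type."""
    posterior_keypoint: str
    anterior_keypoint: str
    confidence_score: float          # R×M
    pc1_vector: np.ndarray
    anterior_sign: int               # +1 or -1
    sorted_candidate_nodes: list = field(default_factory=list)

# Hypothetical call pattern for the proposed API (not runnable against
# this PR as-is; shown only to illustrate the intended flow):
#   result = infer_body_axis(ds)                        # auto-detect pair
#   report = validate_ap_pair(ds, "tail_base", "nose")  # check a user pair
#   ap = project_keypoints(ds, result)                  # time × keypoints × individuals
```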

TO DO:

GT nodes tuned the filter; non-GT nodes will test whether it learned or memorized. If Step 1 survivors outside the GT set land in AP positions that manual inspection confirms as anatomically correct, the cascade discovered genuine structure; if not, it overfit to the GT subset.

@sonarqubecloud

sonarqubecloud Bot commented Apr 8, 2026

